Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets

Neural Information Processing Systems

Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs can be prohibitively high. Ensemble Distribution Distillation (EnD$^2$) is an approach that allows a single model to efficiently capture both the predictive performance and uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members' output distributions via the maximum likelihood criterion. Although this criterion is theoretically principled, this work shows that it exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. Specifically, we show that, for the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. Hence, during training, the model focuses on the distribution of the ensemble tail-class probabilities rather than on the probability of the correct and closely related classes. We propose a new training objective which minimizes the reverse KL-divergence to a \emph{Proxy-Dirichlet} target derived from the ensemble. This loss resolves the gradient issues of EnD$^2$, as we demonstrate both theoretically and empirically on the ImageNet, LibriSpeech, and WMT17 En-De datasets, which contain 1000, 5000, and 40,000 classes, respectively.
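
The maximum likelihood criterion described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the helper names (`digamma`, `dirichlet_log_pdf`, `end2_nll`, `end2_nll_grad`) are hypothetical, the Dirichlet parameters `alpha` are passed in directly rather than produced by a network, and the gradient is taken with respect to `alpha`, not the network's logits.

```python
import math
from typing import List, Sequence

# Illustrative sketch of the EnD^2 criterion; hypothetical helpers, not the authors' code.


def digamma(x: float) -> float:
    """psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def dirichlet_log_pdf(pi: Sequence[float], alpha: Sequence[float]) -> float:
    """log Dir(pi; alpha) = ln Gamma(alpha_0) - sum_k ln Gamma(alpha_k) + sum_k (alpha_k - 1) ln pi_k."""
    alpha0 = sum(alpha)
    return (math.lgamma(alpha0) - sum(math.lgamma(a) for a in alpha)
            + sum((a - 1.0) * math.log(p) for a, p in zip(alpha, pi)))


def end2_nll(alpha: Sequence[float], ensemble_probs: Sequence[Sequence[float]]) -> float:
    """Negative Dirichlet log-likelihood of the members' output distributions, averaged over members."""
    return -sum(dirichlet_log_pdf(pi, alpha) for pi in ensemble_probs) / len(ensemble_probs)


def end2_nll_grad(alpha: Sequence[float], ensemble_probs: Sequence[Sequence[float]]) -> List[float]:
    """d end2_nll / d alpha_k = psi(alpha_k) - psi(alpha_0) - mean_m ln pi_k^(m)."""
    num_members = len(ensemble_probs)
    psi_alpha0 = digamma(sum(alpha))
    grads = []
    for k, a in enumerate(alpha):
        mean_log_pi = sum(math.log(pi[k]) for pi in ensemble_probs) / num_members
        grads.append(digamma(a) - psi_alpha0 - mean_log_pi)
    return grads


# Usage: uniform alpha = (1, 1, 1) and three members that agree on a 0.01 tail class.
ensemble = [[0.90, 0.09, 0.01], [0.85, 0.14, 0.01], [0.88, 0.11, 0.01]]
grad = end2_nll_grad([1.0, 1.0, 1.0], ensemble)
# With equal alpha_k the digamma terms are identical across classes, so the gradients differ
# only through -mean_m ln pi_k, which is largest for the 0.01 tail class.
```

The `-mean_m ln pi_k` term grows without bound as an ensemble's probability for class k approaches zero, which gives one intuition for why tail classes can dominate the gradient; this toy case only illustrates the direction of the effect at a single parameter setting and is not the paper's analysis.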


A Loss Derivation

Neural Information Processing Systems

In this section we provide a more detailed derivation of the proposed loss function (Equation 17). We make use of the fact that the negative entropy of the Dirichlet distribution is equivalent to the reverse KL-divergence to a flat Dirichlet, up to an additive constant which does not depend on the …

… We resolved this by using a single LayerNorm layer just before the final output layer. We suspect that a more numerically stable implementation of the loss would not require LayerNorm. Additionally, we examined the models' median precisions (…

… Let's examine how to emulate an ensemble of auto-regressive models using Prior Networks.

Measures of Uncertainty

Let's examine how, given this model, we can obtain measures of sequence-level total and knowledge uncertainty.
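
The entropy identity invoked in the derivation can be checked numerically. For a flat Dirichlet $\mathrm{Dir}(\mathbf{1})$ over $K$ classes, whose density is the constant $\Gamma(K)$ on the simplex, $\mathrm{KL}[\mathrm{Dir}(\boldsymbol{\alpha})\,\|\,\mathrm{Dir}(\mathbf{1})] = -\mathcal{H}[\mathrm{Dir}(\boldsymbol{\alpha})] - \ln\Gamma(K)$, so the additive constant is $\ln\Gamma(K)$ and does not depend on $\boldsymbol{\alpha}$. The sketch below uses the standard closed forms for the Dirichlet entropy and the Dirichlet-to-Dirichlet KL-divergence; the function names are illustrative, not taken from the paper.

```python
import math
from typing import Sequence


def digamma(x: float) -> float:
    """psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def dirichlet_entropy(alpha: Sequence[float]) -> float:
    """H[Dir(alpha)] = ln B(alpha) + (alpha_0 - K) psi(alpha_0) - sum_k (alpha_k - 1) psi(alpha_k)."""
    alpha0 = sum(alpha)
    log_beta = sum(math.lgamma(a) for a in alpha) - math.lgamma(alpha0)
    return (log_beta + (alpha0 - len(alpha)) * digamma(alpha0)
            - sum((a - 1.0) * digamma(a) for a in alpha))


def dirichlet_kl(alpha: Sequence[float], beta: Sequence[float]) -> float:
    """Closed-form KL[Dir(alpha) || Dir(beta)]."""
    alpha0, beta0 = sum(alpha), sum(beta)
    psi_alpha0 = digamma(alpha0)
    return (math.lgamma(alpha0) - sum(math.lgamma(a) for a in alpha)
            - math.lgamma(beta0) + sum(math.lgamma(b) for b in beta)
            + sum((a - b) * (digamma(a) - psi_alpha0) for a, b in zip(alpha, beta)))


# Compare -H[Dir(alpha)] with KL[Dir(alpha) || Dir(1)] for an arbitrary alpha.
alpha = [2.0, 0.5, 3.0, 1.5]
flat = [1.0] * len(alpha)
kl_to_flat = dirichlet_kl(alpha, flat)
neg_entropy = -dirichlet_entropy(alpha)
offset = neg_entropy - kl_to_flat  # ln Gamma(K) = ln 3! for K = 4, whatever alpha is
```

Because the offset is independent of `alpha`, minimizing the negative entropy and minimizing the reverse KL-divergence to the flat Dirichlet produce the same gradients.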


Scaling Ensemble Distribution Distillation to Many Classes with Proxy Targets

Ryabinin, Max, Malinin, Andrey, Gales, Mark

arXiv.org Artificial Intelligence

Ensembles of machine learning models yield improved system performance as well as robust and interpretable uncertainty estimates; however, their inference costs may often be prohibitively high. Ensemble Distribution Distillation is an approach that allows a single model to efficiently capture both the predictive performance and uncertainty estimates of an ensemble. For classification, this is achieved by training a Dirichlet distribution over the ensemble members' output distributions via the maximum likelihood criterion. Although theoretically principled, this criterion exhibits poor convergence when applied to large-scale tasks where the number of classes is very high. In our work, we analyze this effect and show that, for the Dirichlet log-likelihood criterion, classes with low probability induce larger gradients than high-probability classes. This forces the model to focus on the distribution of the ensemble tail-class probabilities. We propose a new training objective which minimizes the reverse KL-divergence to a Proxy-Dirichlet target derived from the ensemble. This loss resolves the gradient issues of Ensemble Distribution Distillation, as we demonstrate both theoretically and empirically on the ImageNet and WMT17 En-De datasets, which contain 1000 and 40,000 classes, respectively.
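
The proposed objective, a reverse KL-divergence from the model's Dirichlet to a Proxy-Dirichlet target, can be sketched as follows. The abstract does not say how the proxy is built, so the construction here is an assumption made for illustration: the target keeps the ensemble's mean prediction $\hat{\pi}_k$ and takes its precision from the large-precision approximation $\hat{\beta} \approx (K-1) / \bigl(2\sum_k \hat{\pi}_k(\ln\hat{\pi}_k - \overline{\ln \pi_k})\bigr)$, which follows from $\psi(x) \approx \ln x - 1/(2x)$, where $\overline{\ln \pi_k}$ is the mean log-probability over members. The names `proxy_dirichlet` and `reverse_kl_loss` are hypothetical.

```python
import math
from typing import List, Sequence


def digamma(x: float) -> float:
    """psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def dirichlet_kl(alpha: Sequence[float], beta: Sequence[float]) -> float:
    """Closed-form KL[Dir(alpha) || Dir(beta)]."""
    alpha0, beta0 = sum(alpha), sum(beta)
    psi_alpha0 = digamma(alpha0)
    return (math.lgamma(alpha0) - sum(math.lgamma(a) for a in alpha)
            - math.lgamma(beta0) + sum(math.lgamma(b) for b in beta)
            + sum((a - b) * (digamma(a) - psi_alpha0) for a, b in zip(alpha, beta)))


def proxy_dirichlet(ensemble_probs: Sequence[Sequence[float]]) -> List[float]:
    """ASSUMED proxy target: ensemble mean scaled by an approximate ML precision."""
    num_members = len(ensemble_probs)
    num_classes = len(ensemble_probs[0])
    mean = [sum(pi[k] for pi in ensemble_probs) / num_members for k in range(num_classes)]
    mean_log = [sum(math.log(pi[k]) for pi in ensemble_probs) / num_members
                for k in range(num_classes)]
    # By Jensen's inequality ln(mean_k) >= mean_log_k, so spread >= 0 (0 only if members agree).
    spread = sum(m * (math.log(m) - lm) for m, lm in zip(mean, mean_log))
    # Large-precision approximation: precision ~ (K - 1) / (2 * spread); floor avoids division by 0.
    precision = (num_classes - 1) / (2.0 * max(spread, 1e-12))
    return [precision * m for m in mean]


def reverse_kl_loss(model_alpha: Sequence[float], target_alpha: Sequence[float]) -> float:
    """Reverse KL: the model's Dirichlet comes first, the proxy target second."""
    return dirichlet_kl(model_alpha, target_alpha)


# Usage: three ensemble members over three classes, and one hypothetical model output.
ensemble = [[0.7, 0.2, 0.1], [0.6, 0.3, 0.1], [0.8, 0.1, 0.1]]
target = proxy_dirichlet(ensemble)
loss = reverse_kl_loss([2.0, 1.0, 1.0], target)
```

Under this assumed construction the target's mean equals the ensemble mean, and members that disagree more produce a lower target precision, so the proxy still encodes ensemble diversity while the model only has to match a single Dirichlet rather than the individual member outputs.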


Ensemble Distribution Distillation

Malinin, Andrey, Mlodozeniec, Bruno, Gales, Mark

arXiv.org Machine Learning

Ensembles of Neural Network (NN) models are known to yield improvements in accuracy. Furthermore, they have been empirically shown to yield robust measures of uncertainty, though without theoretical guarantees. However, ensembles come at a high computational and memory cost, which may be prohibitive for certain applications. There has been significant work on the distillation of an ensemble into a single model. Such approaches decrease computational cost and allow a single model to achieve accuracy comparable to that of an ensemble. However, information about the \emph{diversity} of the ensemble, which can yield estimates of \emph{knowledge uncertainty}, is lost. Recently, a new class of models, called Prior Networks, has been proposed, which allows a single neural network to explicitly model a distribution over output distributions, effectively emulating an ensemble. In this work, ensembles and Prior Networks are combined to yield a novel approach called \emph{Ensemble Distribution Distillation} (EnD$^2$), which allows distilling an ensemble into a single Prior Network. This allows a single model to retain both the improved classification performance and the measures of diversity of the ensemble. In this initial investigation, the properties of EnD$^2$ are examined and confirmed on an artificial dataset.
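
The link between ensemble diversity and knowledge uncertainty can be illustrated with the standard information-theoretic decomposition: total uncertainty is the entropy of the averaged prediction, expected data uncertainty is the average entropy of the individual predictions, and knowledge uncertainty is their difference (the mutual information). The sketch computes this both from an explicit ensemble and, in closed form, from the Dirichlet parameters a Prior Network would output, using the identity $\mathbb{E}_{\mathrm{Dir}(\boldsymbol{\alpha})}[\pi_k \ln \pi_k] = \frac{\alpha_k}{\alpha_0}\bigl(\psi(\alpha_k + 1) - \psi(\alpha_0 + 1)\bigr)$. The function names are illustrative.

```python
import math
from typing import Sequence, Tuple


def digamma(x: float) -> float:
    """psi(x) for x > 0: recurrence psi(x) = psi(x + 1) - 1/x, then an asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x
        x += 1.0
    inv2 = 1.0 / (x * x)
    series = inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 * (1 / 252 - inv2 * (1 / 240 - inv2 / 132))))
    return result + math.log(x) - 0.5 / x - series


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy (nats) of a categorical distribution."""
    return -sum(x * math.log(x) for x in p if x > 0.0)


def ensemble_uncertainty(ensemble_probs: Sequence[Sequence[float]]) -> Tuple[float, float, float]:
    """(total, expected data, knowledge) uncertainty from M member distributions."""
    num_members = len(ensemble_probs)
    mean = [sum(col) / num_members for col in zip(*ensemble_probs)]
    total = entropy(mean)
    data = sum(entropy(pi) for pi in ensemble_probs) / num_members
    return total, data, total - data


def dirichlet_uncertainty(alpha: Sequence[float]) -> Tuple[float, float, float]:
    """Same decomposition in closed form for Dir(alpha), as output by a Prior Network."""
    alpha0 = sum(alpha)
    total = entropy([a / alpha0 for a in alpha])
    psi_next0 = digamma(alpha0 + 1.0)
    data = -sum((a / alpha0) * (digamma(a + 1.0) - psi_next0) for a in alpha)
    return total, data, total - data
```

Identical members give zero knowledge uncertainty while members that disagree give a large one; for the Dirichlet, knowledge uncertainty shrinks as the precision $\alpha_0$ grows, which is the diversity information a distilled Prior Network is meant to retain.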